// -*- C++ -*-
#ifndef JAMA_SVD_H
#define JAMA_SVD_H

#include "tnt/array1d.h"
#include "tnt/array1d_utils.h"
#include "tnt/array2d.h"
#include "tnt/array2d_utils.h"
#include "tnt/math_utils.h"

#include <algorithm>
// for min(), max() below
#include <cmath>
// for abs() below

// using namespace TNT;
// using namespace std;

namespace JAMA {
/** Singular Value Decomposition.
    <P>
    For an m-by-n matrix A with m >= n, the singular value decomposition is
    an m-by-n orthogonal matrix U, an n-by-n diagonal matrix S, and
    an n-by-n orthogonal matrix V so that A = U*S*V'.
    <P>
    The singular values, sigma[k] = S[k][k], are ordered so that
    sigma[0] >= sigma[1] >= ... >= sigma[n-1].
    <P>
    The singular value decomposition always exists, so the constructor will
    never fail.  The matrix condition number and the effective numerical
    rank can be computed from this decomposition.

    <p>
    (Adapted from JAMA, a Java Matrix Library, developed jointly
    by the Mathworks and NIST; see  http://math.nist.gov/javanumerics/jama).
*/
template <class Real> class SVD {

  // Left singular vectors (m rows, min(m, n) columns) and right singular
  // vectors (n rows, n columns).
  TNT::Array2D<Real> U, V;
  // Singular values; allocated with min(m + 1, n) entries.
  TNT::Array1D<Real> s;
  // Row and column counts of the matrix handed to the constructor.
  int m, n;

public:
  /** Compute the decomposition of Arg.

      The input is copied before being reduced, so Arg itself is not
      modified.  The matrix is first reduced to bidiagonal form with
      Householder transformations (diagonal in s, super-diagonal in e),
      the transformations are accumulated into U and V, and the
      bidiagonal matrix is then iterated until every super-diagonal
      element is negligible.  Singular values are finally made
      non-negative and sorted into descending order.

      @param Arg  m-by-n matrix to decompose.
  */
  SVD(const TNT::Array2D<Real> &Arg) {

    m = Arg.dim1();
    n = Arg.dim2();
    const int nu = std::min(m, n);
    s = TNT::Array1D<Real>(std::min(m + 1, n));
    U = TNT::Array2D<Real>(m, nu, Real(0));
    V = TNT::Array2D<Real>(n, n);

    // An empty matrix has nothing to reduce.  The general code below
    // would index e[-1] (n == 0) or s[-1] (m == 0), so finish here with
    // zero singular values and V set to the identity.
    if (m <= 0 || n <= 0) {
      for (int q = 0; q < s.dim(); q++) {
        s[q] = 0.0;
      }
      for (int r = 0; r < n; r++) {
        for (int c = 0; c < n; c++) {
          V[r][c] = 0.0;
        }
        V[r][r] = 1.0;
      }
      return;
    }

    TNT::Array1D<Real> e(n);
    TNT::Array1D<Real> work(m);
    TNT::Array2D<Real> A(Arg.copy());
    const bool wantu = true;
    const bool wantv = true;
    int i = 0, j = 0, k = 0;

    // Reduce A to bidiagonal form, storing the diagonal elements
    // in s and the super-diagonal elements in e.

    const int nct = std::min(m - 1, n);                 // column transformations
    const int nrt = std::max(0, std::min(n - 2, m));    // row transformations
    for (k = 0; k < std::max(nct, nrt); k++) {
      if (k < nct) {

        // Compute the transformation for the k-th column and
        // place the k-th diagonal in s[k].
        // Compute 2-norm of k-th column without under/overflow.
        s[k] = 0;
        for (i = k; i < m; i++) {
          s[k] = hypot(s[k], A[i][k]);
        }
        if (s[k] != 0.0) {
          if (A[k][k] < 0.0) {
            s[k] = -s[k];
          }
          for (i = k; i < m; i++) {
            A[i][k] /= s[k];
          }
          A[k][k] += 1.0;
        }
        s[k] = -s[k];
      }
      for (j = k + 1; j < n; j++) {
        if ((k < nct) && (s[k] != 0.0)) {

          // Apply the transformation.

          Real t(0.0);
          for (i = k; i < m; i++) {
            t += A[i][k] * A[i][j];
          }
          t = -t / A[k][k];
          for (i = k; i < m; i++) {
            A[i][j] += t * A[i][k];
          }
        }

        // Place the k-th row of A into e for the
        // subsequent calculation of the row transformation.

        e[j] = A[k][j];
      }
      if (wantu && (k < nct)) {

        // Place the transformation in U for subsequent back
        // multiplication.

        for (i = k; i < m; i++) {
          U[i][k] = A[i][k];
        }
      }
      if (k < nrt) {

        // Compute the k-th row transformation and place the
        // k-th super-diagonal in e[k].
        // Compute 2-norm without under/overflow.
        e[k] = 0;
        for (i = k + 1; i < n; i++) {
          e[k] = hypot(e[k], e[i]);
        }
        if (e[k] != 0.0) {
          if (e[k + 1] < 0.0) {
            e[k] = -e[k];
          }
          for (i = k + 1; i < n; i++) {
            e[i] /= e[k];
          }
          e[k + 1] += 1.0;
        }
        e[k] = -e[k];
        if ((k + 1 < m) && (e[k] != 0.0)) {

          // Apply the transformation.

          for (i = k + 1; i < m; i++) {
            work[i] = 0.0;
          }
          for (j = k + 1; j < n; j++) {
            for (i = k + 1; i < m; i++) {
              work[i] += e[j] * A[i][j];
            }
          }
          for (j = k + 1; j < n; j++) {
            Real t(-e[j] / e[k + 1]);
            for (i = k + 1; i < m; i++) {
              A[i][j] += t * work[i];
            }
          }
        }
        if (wantv) {

          // Place the transformation in V for subsequent
          // back multiplication.

          for (i = k + 1; i < n; i++) {
            V[i][k] = e[i];
          }
        }
      }
    }

    // Set up the final bidiagonal matrix of order p.

    int p = std::min(n, m + 1);
    if (nct < n) {
      s[nct] = A[nct][nct];
    }
    if (m < p) {
      s[p - 1] = 0.0;
    }
    if (nrt + 1 < p) {
      e[nrt] = A[nrt][p - 1];
    }
    e[p - 1] = 0.0;

    // If required, generate U.

    if (wantu) {
      for (j = nct; j < nu; j++) {
        for (i = 0; i < m; i++) {
          U[i][j] = 0.0;
        }
        U[j][j] = 1.0;
      }
      for (k = nct - 1; k >= 0; k--) {
        if (s[k] != 0.0) {
          for (j = k + 1; j < nu; j++) {
            Real t(0.0);
            for (i = k; i < m; i++) {
              t += U[i][k] * U[i][j];
            }
            t = -t / U[k][k];
            for (i = k; i < m; i++) {
              U[i][j] += t * U[i][k];
            }
          }
          for (i = k; i < m; i++) {
            U[i][k] = -U[i][k];
          }
          U[k][k] = 1.0 + U[k][k];
          // Clear every entry above the diagonal of column k.  (The
          // bound used to be k - 1, which skipped row k - 1; that row was
          // already zero from U's initialisation, so results are unchanged.)
          for (i = 0; i < k; i++) {
            U[i][k] = 0.0;
          }
        } else {
          for (i = 0; i < m; i++) {
            U[i][k] = 0.0;
          }
          U[k][k] = 1.0;
        }
      }
    }

    // If required, generate V.

    if (wantv) {
      for (k = n - 1; k >= 0; k--) {
        if ((k < nrt) && (e[k] != 0.0)) {
          // V has n columns, so the k-th reflector must be applied to all
          // of columns k+1 .. n-1.  Stopping at min(m, n) left the
          // trailing columns untransformed whenever m < n.
          for (j = k + 1; j < n; j++) {
            Real t(0.0);
            for (i = k + 1; i < n; i++) {
              t += V[i][k] * V[i][j];
            }
            t = -t / V[k + 1][k];
            for (i = k + 1; i < n; i++) {
              V[i][j] += t * V[i][k];
            }
          }
        }
        for (i = 0; i < n; i++) {
          V[i][k] = 0.0;
        }
        V[k][k] = 1.0;
      }
    }

    // Main iteration loop for the singular values.

    const int pp = p - 1;
    int iter = 0;
    // 2^-52: the convergence threshold is fixed at this value regardless
    // of the precision of Real.
    Real eps(pow(2.0, -52.0));
    while (p > 0) {
      int kase = 0;
      k = 0;

      // Here is where a test for too many iterations would go.

      // This section of the program inspects for
      // negligible elements in the s and e arrays.  On
      // completion the variables kase and k are set as follows.

      // kase = 1     if s(p) and e[k-1] are negligible and k<p
      // kase = 2     if s(k) is negligible and k<p
      // kase = 3     if e[k-1] is negligible, k<p, and
      //              s(k), ..., s(p) are not negligible (qr step).
      // kase = 4     if e(p-1) is negligible (convergence).

      for (k = p - 2; k >= -1; k--) {
        if (k == -1) {
          break;
        }
        if (std::abs(e[k]) <= eps * (std::abs(s[k]) + std::abs(s[k + 1]))) {
          e[k] = 0.0;
          break;
        }
      }
      if (k == p - 2) {
        kase = 4;
      } else {
        int ks;
        for (ks = p - 1; ks >= k; ks--) {
          if (ks == k) {
            break;
          }
          Real t((ks != p ? std::abs(e[ks]) : 0.) +
                 (ks != k + 1 ? std::abs(e[ks - 1]) : 0.));
          if (std::abs(s[ks]) <= eps * t) {
            s[ks] = 0.0;
            break;
          }
        }
        if (ks == k) {
          kase = 3;
        } else if (ks == p - 1) {
          kase = 1;
        } else {
          kase = 2;
          k = ks;
        }
      }
      k++;

      // Perform the task indicated by kase.

      switch (kase) {

      // Deflate negligible s(p).

      case 1: {
        Real f(e[p - 2]);
        e[p - 2] = 0.0;
        for (j = p - 2; j >= k; j--) {
          Real t(hypot(s[j], f));
          Real cs(s[j] / t);
          Real sn(f / t);
          s[j] = t;
          if (j != k) {
            f = -sn * e[j - 1];
            e[j - 1] = cs * e[j - 1];
          }
          if (wantv) {
            for (i = 0; i < n; i++) {
              t = cs * V[i][j] + sn * V[i][p - 1];
              V[i][p - 1] = -sn * V[i][j] + cs * V[i][p - 1];
              V[i][j] = t;
            }
          }
        }
      } break;

      // Split at negligible s(k).

      case 2: {
        Real f(e[k - 1]);
        e[k - 1] = 0.0;
        for (j = k; j < p; j++) {
          Real t(hypot(s[j], f));
          Real cs(s[j] / t);
          Real sn(f / t);
          s[j] = t;
          f = -sn * e[j];
          e[j] = cs * e[j];
          if (wantu) {
            for (i = 0; i < m; i++) {
              t = cs * U[i][j] + sn * U[i][k - 1];
              U[i][k - 1] = -sn * U[i][j] + cs * U[i][k - 1];
              U[i][j] = t;
            }
          }
        }
      } break;

      // Perform one qr step.

      case 3: {

        // Calculate the shift.  Everything is divided by the largest
        // magnitude involved before the products are formed.

        Real scale = std::max(
            std::max(std::max(std::max(std::abs(s[p - 1]), std::abs(s[p - 2])),
                              std::abs(e[p - 2])),
                     std::abs(s[k])),
            std::abs(e[k]));
        Real sp = s[p - 1] / scale;
        Real spm1 = s[p - 2] / scale;
        Real epm1 = e[p - 2] / scale;
        Real sk = s[k] / scale;
        Real ek = e[k] / scale;
        Real b = ((spm1 + sp) * (spm1 - sp) + epm1 * epm1) / 2.0;
        Real c = (sp * epm1) * (sp * epm1);
        Real shift = 0.0;
        if ((b != 0.0) || (c != 0.0)) {
          shift = std::sqrt(b * b + c);
          if (b < 0.0) {
            shift = -shift;
          }
          shift = c / (b + shift);
        }
        Real f = (sk + sp) * (sk - sp) + shift;
        Real g = sk * ek;

        // Chase zeros.

        for (j = k; j < p - 1; j++) {
          Real t = hypot(f, g);
          Real cs = f / t;
          Real sn = g / t;
          if (j != k) {
            e[j - 1] = t;
          }
          f = cs * s[j] + sn * e[j];
          e[j] = cs * e[j] - sn * s[j];
          g = sn * s[j + 1];
          s[j + 1] = cs * s[j + 1];
          if (wantv) {
            for (i = 0; i < n; i++) {
              t = cs * V[i][j] + sn * V[i][j + 1];
              V[i][j + 1] = -sn * V[i][j] + cs * V[i][j + 1];
              V[i][j] = t;
            }
          }
          t = hypot(f, g);
          cs = f / t;
          sn = g / t;
          s[j] = t;
          f = cs * e[j] + sn * s[j + 1];
          s[j + 1] = -sn * e[j] + cs * s[j + 1];
          g = sn * e[j + 1];
          e[j + 1] = cs * e[j + 1];
          // U has min(m, n) columns; the guard keeps j + 1 inside it.
          if (wantu && (j < m - 1)) {
            for (i = 0; i < m; i++) {
              t = cs * U[i][j] + sn * U[i][j + 1];
              U[i][j + 1] = -sn * U[i][j] + cs * U[i][j + 1];
              U[i][j] = t;
            }
          }
        }
        e[p - 2] = f;
        iter = iter + 1;
      } break;

      // Convergence.

      case 4: {

        // Make the singular values positive.

        if (s[k] <= 0.0) {
          s[k] = (s[k] < 0.0 ? -s[k] : 0.0);
          if (wantv) {
            // Negate the whole column: V has n rows.  The previous bound
            // (i <= pp) covered only min(n, m + 1) of them.
            for (i = 0; i < n; i++) {
              V[i][k] = -V[i][k];
            }
          }
        }

        // Order the singular values: bubble s[k] towards the end while it
        // is smaller than its successor, swapping the matching columns.

        while (k < pp) {
          if (s[k] >= s[k + 1]) {
            break;
          }
          Real t = s[k];
          s[k] = s[k + 1];
          s[k + 1] = t;
          if (wantv && (k < n - 1)) {
            for (i = 0; i < n; i++) {
              t = V[i][k + 1];
              V[i][k + 1] = V[i][k];
              V[i][k] = t;
            }
          }
          if (wantu && (k < m - 1)) {
            for (i = 0; i < m; i++) {
              t = U[i][k + 1];
              U[i][k + 1] = U[i][k];
              U[i][k] = t;
            }
          }
          k++;
        }
        iter = 0;
        p--;
      } break;
      }
    }
  }

  /** Return the left singular vectors.

      A is replaced by a newly allocated m-by-min(m+1,n) array.  U itself
      holds only min(m,n) columns, so when m < n the one extra column is
      filled with zeros instead of being read from beyond U's storage.

      @param A  receives the result.
  */

  void getU(TNT::Array2D<Real> &A) {
    const int outCols = std::min(m + 1, n);
    const int storedCols = std::min(m, n);

    A = TNT::Array2D<Real>(m, outCols);

    for (int i = 0; i < m; i++)
      for (int j = 0; j < outCols; j++)
        A[i][j] = (j < storedCols) ? U[i][j] : Real(0);
  }

  /** Return the right singular vectors (A is assigned from V). */

  void getV(TNT::Array2D<Real> &A) { A = V; }

  /** Return the one-dimensional array of singular values */

  void getSingularValues(TNT::Array1D<Real> &x) { x = s; }

  /** Return the diagonal matrix of singular values.

      A is replaced by a newly allocated n-by-n array.  Only
      min(m+1,n) singular values are stored; diagonal entries past that
      count are left at zero rather than read from beyond s.

      @param A  receives S.
  */

  void getS(TNT::Array2D<Real> &A) {
    A = TNT::Array2D<Real>(n, n);
    const int stored = s.dim();
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < n; j++) {
        A[i][j] = 0.0;
      }
      if (i < stored) {
        A[i][i] = s[i];
      }
    }
  }

  /** Two norm: the largest singular value (zero if none are stored). */

  Real norm2() { return s.dim() > 0 ? s[0] : Real(0); }

  /** Two-norm condition number: largest singular value divided by
      singular value number min(m,n).  Requires min(m,n) >= 1.
  */

  Real cond() { return s[0] / s[std::min(m, n) - 1]; }

  /** Effective numerical matrix rank
      @return     Number of singular values larger than
                  max(m,n) * s[0] * 2^-52.
  */

  int rank() {
    if (s.dim() == 0) {
      return 0;
    }
    Real eps = pow(2.0, -52.0);
    Real tol = std::max(m, n) * s[0] * eps;
    int r = 0;
    for (int i = 0; i < s.dim(); i++) {
      if (s[i] > tol) {
        r++;
      }
    }
    return r;
  }
};
}
#endif
// JAMA_SVD_H
